{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "martial-negative",
   "metadata": {
    "id": "martial-negative",
    "papermill": {
     "duration": 0.024232,
     "end_time": "2021-04-03T12:55:25.038730",
     "exception": false,
     "start_time": "2021-04-03T12:55:25.014498",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# Sensitivity Analysis for Unobserved Confounder with DML.\n",
    "\n",
    "Here we experiment with sensitivity analysis under unobserved confounding, both manually and with [sensitivity.py](https://github.com/vsyrgkanis/dml_sensitivity_python)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "criminal-workplace",
   "metadata": {
    "id": "criminal-workplace",
    "papermill": {
     "duration": 0.019939,
     "end_time": "2021-04-03T12:55:25.120184",
     "exception": false,
     "start_time": "2021-04-03T12:55:25.100245",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Partially Linear SEM\n",
    "\n",
    "Consider the SEM\n",
    "\\begin{eqnarray*}\n",
    "Y & := & \\alpha D + \\delta A + f_Y(X) + \\epsilon_Y,  \\\\\n",
    "D & := & \\gamma A + f_D(X) + \\epsilon_D, \\\\\n",
    "A & : =  & f_A(X) + \\epsilon_A, \\\\\n",
    "X & := &  \\epsilon_X,\n",
    "\\end{eqnarray*}\n",
    "where, conditional on $X$, $\\epsilon_Y, \\epsilon_D, \\epsilon_A$ are both mean zero\n",
    "and mutually uncorrelated. We further normalize\n",
    "$$\n",
    "E[\\epsilon_A^2] =1.\n",
    "$$\n",
    "The key structural\n",
    "parameter is $\\alpha$: $$\\alpha = \\partial_d Y(d)$$\n",
    "where $$Y(d) := (Y: do (D=d)).$$\n",
    "\n",
    "To give context to our example, we can interpret $Y$ as earnings,\n",
    "$D$ as education, $A$ as ability, and $X$ as a set of observed background variables. In this example, we can interpret $\\alpha$ as the returns to schooling.\n",
    "\n",
    "We start by applying the partialling out operator to get rid of the $X$'s in all of the equations. Define the partialling out operation of any random vector $V$ with respect to another random vector $X$ as the residual that is left after subtracting the best predictor of $V$ given $X$:\n",
    "$$\\tilde V = V - E [V \\mid X].$$  \n",
    "If $f$'s are linear, we can replace $E [V \\mid X]$\n",
    "by linear projection.  After partialling out, we have a simplified system:\n",
    "\\begin{eqnarray*}\n",
    "\\tilde Y & := & \\alpha \\tilde D + \\delta \\tilde A + \\epsilon_Y,  \\\\\n",
    "\\tilde D & := & \\gamma \\tilde A + \\epsilon_D, \\\\\n",
    "\\tilde A & : = & \\epsilon_A,\n",
    "\\end{eqnarray*}\n",
    "where $\\epsilon_Y$, $\\epsilon_D$, and $\\epsilon_A$ are uncorrelated.\n",
    "\n",
    "Then the projection of $\\tilde Y$ on $\\tilde D$ recovers\n",
    "$$\n",
    "\\beta = E [\\tilde Y \\tilde D]/ E [\\tilde D^2] = \\alpha +  \\phi,\n",
    "$$\n",
    "where\n",
    "$$\n",
    "\\phi =  \\delta \\gamma/ E \\left[(\\gamma^2 + \\epsilon^2_D)\\right],\n",
    "$$\n",
    "is the omitted confounder bias or omitted variable bias.\n",
    "\n",
    "The formula follows from inserting the expression for $\\tilde D$ into the definition of $\\beta$ and then simplifying the resulting expression using the assumptions on the $\\epsilon$'s.\n",
    "\n",
    "We can use this formula to bound $\\phi$ directly by making assumptions on the size of $\\delta$\n",
    "and $\\gamma$.  An alternative approach can be based on the following characterization,\n",
    "based on partial $R^2$'s.  This characterization essentially follows\n",
    "from Cinelli and Hazlett, with the slight difference that we have adapted\n",
    "the result to the partially linear model.\n",
    "\n",
    "*Theorem* [Omitted Confounder Bias in Terms of Partial $R^2$'s]\n",
    "\n",
    "In the partially linear SEM setting above,\n",
    "$$\n",
    "\\phi^2 = \\frac{R^2_{\\tilde Y \\sim \\tilde A \\mid \\tilde D} R^2_{\\tilde D \\sim \\tilde A} }{ (1 - R^2_{\\tilde D \\sim \\tilde A}) } \\\n",
    "\\frac{E \\left[ (\\tilde Y - \\beta \\tilde D)^2 \\right] }{E \\left[ ( \\tilde D )^2 \\right]},\n",
    "$$\n",
    "where $R^2_{V \\sim W \\mid X}$ denotes the population $R^2$ in the linear regression of $V$ on $W$, after partialling out linearly $X$ from $V$ and $W$.\n",
    "\n",
    "\n",
    "Therefore, if we place bounds on how much of the variation in $\\tilde Y$ and in $\\tilde D$\n",
    "the unobserved confounder $\\tilde A$ is able to explain, we can bound the omitted confounder bias by $$\\sqrt{\\phi^2}.$$\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "continuous-marshall",
   "metadata": {
    "id": "continuous-marshall",
    "papermill": {
     "duration": 0.020014,
     "end_time": "2021-04-03T12:55:25.160190",
     "exception": false,
     "start_time": "2021-04-03T12:55:25.140176",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# Empirical Example\n",
    "\n",
    "We consider an empirical example based on data surrounding the Darfur war. Specifically, we are interested in the effect of having experienced direct war violence on attitudes towards peace. Data is described here\n",
    "https://cran.r-project.org/web/packages/sensemakr/vignettes/sensemakr.html\n",
    "\n",
    "The main outcome is attitude towards peace -- the peacefactor.\n",
    "The key variable of interest is whether the responders were directly harmed (directlyharmed).\n",
    "We want to know if being directly harmed in the conflict causes people to support peace-enforcing measures.\n",
    "The measured confounders include female indicator, age, farmer, herder, voted in the past, and household size.\n",
    "There is also a village indicator, which we will treat as fixed effect and partial it out before conducting\n",
    "the analysis. The standard errors will be clustered at the village level."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "oW_mOo_wcpmV",
   "metadata": {
    "id": "oW_mOo_wcpmV"
   },
   "source": [
    "\n",
    "## Outline\n",
    "\n",
    "We will:\n",
    "- mimic the partialling out procedure with machine learning tools;\n",
    "- invoking sensitivity.py to compute $\\phi^2$ and plot sensitivity results.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aPAq16YXlcNx",
   "metadata": {
    "id": "aPAq16YXlcNx"
   },
   "outputs": [],
   "source": [
    "# Import relevant packages\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.model_selection import cross_val_predict, KFold\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "import patsy\n",
    "import warnings\n",
    "from sklearn.base import BaseEstimator\n",
    "import statsmodels.formula.api as smf\n",
    "warnings.simplefilter('ignore')\n",
    "np.random.seed(1234)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "qOOPGSB0lWM9",
   "metadata": {
    "id": "qOOPGSB0lWM9"
   },
   "outputs": [],
   "source": [
    "file = \"https://raw.githubusercontent.com/CausalAIBook/MetricsMLNotebooks/main/data/darfur.csv\"\n",
    "data = pd.read_csv(file)\n",
    "data.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "hidden-packing",
   "metadata": {
    "id": "hidden-packing",
    "papermill": {
     "duration": 0.021289,
     "end_time": "2021-04-03T12:55:38.319389",
     "exception": false,
     "start_time": "2021-04-03T12:55:38.298100",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Preprocessing\n",
    "Take out village fixed effects and run basic linear analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "authorized-transformation",
   "metadata": {
    "id": "authorized-transformation",
    "papermill": {
     "duration": 2.339638,
     "end_time": "2021-04-03T12:55:40.680306",
     "exception": false,
     "start_time": "2021-04-03T12:55:38.340668",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Get rid of village fixed effects\n",
    "peacefactorR = smf.ols(formula=\"peacefactor ~ village\", data=data).fit().resid\n",
    "directlyharmedR = smf.ols(formula=\"directlyharmed ~ village\", data=data).fit().resid\n",
    "femaleR = smf.ols(formula=\"female ~ village\", data=data).fit().resid\n",
    "ageR = smf.ols(formula=\"age ~ village\", data=data).fit().resid\n",
    "farmerR = smf.ols(formula=\"farmer_dar ~ village\", data=data).fit().resid\n",
    "herderR = smf.ols(formula=\"herder_dar ~ village\", data=data).fit().resid\n",
    "pastvotedR = smf.ols(formula=\"pastvoted ~ village\", data=data).fit().resid\n",
    "hhsizeR = smf.ols(formula=\"hhsize_darfur ~ village\", data=data).fit().resid\n",
    "\n",
    "# Preliminary linear model analysis\n",
    "model1 = smf.ols(formula=\"peacefactorR ~ directlyharmedR + femaleR + ageR + farmerR + herderR + pastvotedR + hhsizeR\",\n",
    "                 data=data).fit(cov_type='cluster', cov_kwds={'groups': data['village']})\n",
    "print(model1.summary())\n",
    "\n",
    "# Here we are clustering standard errors at the village level\n",
    "model2 = smf.ols(formula=\"peacefactorR ~ femaleR + ageR + farmerR + herderR + pastvotedR + hhsizeR\",\n",
    "                 data=data).fit(cov_type='cluster', cov_kwds={'groups': data['village']})\n",
    "print(model2.summary())\n",
    "\n",
    "model3 = smf.ols(formula=\"directlyharmedR ~ femaleR + ageR + farmerR + herderR + pastvotedR + hhsizeR\",\n",
    "                 data=data).fit(cov_type='cluster', cov_kwds={'groups': data['village']})\n",
    "print(model3.summary())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "careful-dollar",
   "metadata": {
    "id": "careful-dollar",
    "papermill": {
     "duration": 0.041148,
     "end_time": "2021-04-03T12:55:40.762964",
     "exception": false,
     "start_time": "2021-04-03T12:55:40.721816",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Lasso for partialling out controls"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "YHtKZ44_inRb",
   "metadata": {
    "id": "YHtKZ44_inRb"
   },
   "source": [
    "Run the following commands to install hdmpy for rigorous lasso:\n",
    "\n",
    "```\n",
    "!pip install multiprocess\n",
    "!git clone https://github.com/maxhuppertz/hdmpy.git\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8TpwVjQeilSu",
   "metadata": {
    "id": "8TpwVjQeilSu"
   },
   "outputs": [],
   "source": [
    "!pip install multiprocess\n",
    "!git clone https://github.com/maxhuppertz/hdmpy.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "EgDePhZViw2-",
   "metadata": {
    "id": "EgDePhZViw2-"
   },
   "outputs": [],
   "source": [
    "import hdmpy\n",
    "\n",
    "\n",
    "class RLasso(BaseEstimator):\n",
    "\n",
    "    def __init__(self, *, post=True):\n",
    "        self.post = post\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        self.rlasso_ = hdmpy.rlasso(X, y, post=self.post)\n",
    "        return self\n",
    "\n",
    "    def predict(self, X):\n",
    "        return np.array(X) @ np.array(self.rlasso_.est['beta']).flatten() + np.array(self.rlasso_.est['intercept'])\n",
    "\n",
    "\n",
    "def lasso_model():\n",
    "    return RLasso(post=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "TFKiFOk2ILI-",
   "metadata": {
    "id": "TFKiFOk2ILI-"
   },
   "outputs": [],
   "source": [
    "Z = np.column_stack((femaleR, ageR, farmerR, herderR, pastvotedR, hhsizeR))\n",
    "Z = pd.DataFrame(Z, columns=['femaleR', 'ageR', 'farmerR', 'herderR', 'pastvotedR', 'hhsizeR'])\n",
    "# Interactions of 3 degrees\n",
    "controls = patsy.dmatrix('0 + (femaleR + ageR + farmerR + herderR + pastvotedR + hhsizeR)**3',\n",
    "                         Z, return_type='dataframe')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hdQ9lV5eF5cD",
   "metadata": {
    "id": "hdQ9lV5eF5cD"
   },
   "outputs": [],
   "source": [
    "resY = peacefactorR - lasso_model().fit(controls, peacefactorR).predict(controls)\n",
    "resD = directlyharmedR - lasso_model().fit(controls, directlyharmedR).predict(controls)\n",
    "print((\"Controls explain the following fraction of variance of Outcome\", 1 - np.var(resY) / np.var(peacefactorR)))\n",
    "print((\"Controls explain the following fraction of variance of Treatment\", 1 - np.var(resD) / np.var(directlyharmedR)))\n",
    "\n",
    "dml_data = pd.DataFrame({'resY': resY, 'resD': resD, 'village': data['village']})\n",
    "dml_model = smf.ols(formula=\"resY ~ resD\", data=dml_data).fit(cov_type='cluster',\n",
    "                                                              cov_kwds={'groups': dml_data['village']})\n",
    "print(dml_model.summary())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "built-enlargement",
   "metadata": {
    "id": "built-enlargement",
    "papermill": {
     "duration": 0.02335,
     "end_time": "2021-04-03T12:55:41.169602",
     "exception": false,
     "start_time": "2021-04-03T12:55:41.146252",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Manual Bias Analysis"
   ]
  },
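  {
   "cell_type": "markdown",
   "id": "manual-bias-arithmetic",
   "metadata": {
    "id": "manual-bias-arithmetic"
   },
   "source": [
    "To make the computation in the next cell concrete, here is the arithmetic for the hypothetical values it uses, $R^2_{\\tilde Y \\sim \\tilde A \\mid \\tilde D} = 0.16$ and $R^2_{\\tilde D \\sim \\tilde A} = 0.01$. The theorem gives $\\phi^2 = \\kappa S^2$ with\n",
    "$$\n",
    "\\kappa = \\frac{0.16 \\times 0.01}{1 - 0.01} \\approx 0.0016, \\qquad S^2 = \\frac{E[(\\tilde Y - \\beta \\tilde D)^2]}{E[\\tilde D^2]},\n",
    "$$\n",
    "so that $|\\phi| = \\sqrt{\\kappa}\\, S \\approx 0.040\\, S$. In the code, $S^2$ is estimated by the ratio of the mean squared residual of the final regression to the mean squared treatment residual."
   ]
  },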
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "respective-sister",
   "metadata": {
    "id": "respective-sister",
    "papermill": {
     "duration": 0.380639,
     "end_time": "2021-04-03T12:55:41.573999",
     "exception": false,
     "start_time": "2021-04-03T12:55:41.193360",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Main estimate\n",
    "beta = dml_model.params[1]\n",
    "\n",
    "# Hypothetical values of partial R2s\n",
    "R2_YC = 0.16\n",
    "R2_DC = 0.01\n",
    "\n",
    "# Elements of the bias equation\n",
    "kappa = (R2_YC * R2_DC) / (1 - R2_DC)\n",
    "variance_ratio = np.mean(dml_model.resid**2) / np.mean(resD**2)\n",
    "\n",
    "# Compute square bias\n",
    "BiasSq = kappa * variance_ratio\n",
    "\n",
    "# Compute absolute value of the bias\n",
    "print(\"absolute value of the bias:\", np.sqrt(BiasSq))\n",
    "\n",
    "# Plotting\n",
    "gridR2_DC = np.arange(0, 0.301, 0.001)\n",
    "gridR2_YC = kappa * (1 - gridR2_DC) / gridR2_DC\n",
    "gridR2_YC = np.where(gridR2_YC > 1, 1, gridR2_YC)\n",
    "\n",
    "plt.plot(gridR2_DC, gridR2_YC, color='red')\n",
    "plt.xlabel('Partial R2 of Treatment with Confounder')\n",
    "plt.ylabel('Partial R2 of Outcome with Confounder')\n",
    "plt.title(f'Combo of R2 such that |Bias| < {np.round(np.sqrt(BiasSq), decimals=4)}')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sorted-hands",
   "metadata": {
    "id": "sorted-hands",
    "papermill": {
     "duration": 0.025659,
     "end_time": "2021-04-03T12:55:41.626309",
     "exception": false,
     "start_time": "2021-04-03T12:55:41.600650",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Bias Analysis with some automated functions.\n",
    "\n",
    "We now automate the DML process and pass the estimates to functions to automate the sensitivity analysis. This is done in the R package sensmakr, which does not exist in Python."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3v-cR2kWpUQN",
   "metadata": {
    "id": "3v-cR2kWpUQN"
   },
   "outputs": [],
   "source": [
    "def dml(X, D, y, modely, modeld, *, nfolds, classifier=False, cluster=True, clu=None):\n",
    "    '''\n",
    "    DML for the Partially Linear Model setting with cross-fitting\n",
    "\n",
    "    Input\n",
    "    -----\n",
    "    X: the controls\n",
    "    D: the treatment\n",
    "    y: the outcome\n",
    "    modely: the ML model for predicting the outcome y\n",
    "    modeld: the ML model for predicting the treatment D\n",
    "    nfolds: the number of folds in cross-fitting\n",
    "    classifier: bool, whether the modeld is a classifier or a regressor\n",
    "\n",
    "    clu: df column to cluster by\n",
    "    cluster: bool, whether to use clustered standard errors\n",
    "\n",
    "    Output\n",
    "    ------\n",
    "    point: the point estimate of the treatment effect of D on y\n",
    "    stderr: the standard error of the treatment effect\n",
    "    yhat: the cross-fitted predictions for the outcome y\n",
    "    Dhat: the cross-fitted predictions for the treatment D\n",
    "    resy: the outcome residuals\n",
    "    resD: the treatment residuals\n",
    "    epsilon: the final residual-on-residual OLS regression residual\n",
    "    '''\n",
    "\n",
    "    if nfolds > 1:\n",
    "        cv = KFold(n_splits=nfolds, shuffle=True, random_state=123)  # shuffled k-folds\n",
    "        yhat = cross_val_predict(modely, X, y, cv=cv, n_jobs=-1)  # out-of-fold predictions for y\n",
    "        # out-of-fold predictions for D\n",
    "        # use predict or predict_proba dependent on classifier or regressor for D\n",
    "        if classifier:\n",
    "            Dhat = cross_val_predict(modeld, X, D, cv=cv, method='predict_proba', n_jobs=-1)[:, 1]\n",
    "        else:\n",
    "            Dhat = cross_val_predict(modeld, X, D, cv=cv, n_jobs=-1)\n",
    "    elif nfolds == -1:\n",
    "        yhat = modely.fit(X, y).predict(X)\n",
    "        if classifier:\n",
    "            Dhat = modeld.fit(X, D).predict_proba(X)\n",
    "        else:\n",
    "            Dhat = modeld.fit(X, D).predict(X)\n",
    "\n",
    "    # calculate outcome and treatment residuals\n",
    "    resy = y - yhat\n",
    "    resD = D - Dhat\n",
    "\n",
    "    if cluster:\n",
    "        # final stage ols clustered\n",
    "        dml_data = pd.DataFrame({'resY': resY, 'resD': resD, 'cluster': clu})\n",
    "    else:\n",
    "        # final stage ols nonclustered\n",
    "        dml_data = pd.DataFrame({'resY': resY, 'resD': resD})\n",
    "\n",
    "    if cluster:\n",
    "        # clustered standard errors\n",
    "        ols_mod = smf.ols(formula='resY ~ 1 + resD', data=dml_data)\n",
    "        ols_mod = ols_mod.fit(cov_type='cluster', cov_kwds={\"groups\": dml_data['cluster']})\n",
    "    else:\n",
    "        # regular ols\n",
    "        ols_mod = smf.ols(formula='resY ~ 1 + resD', data=dml_data).fit()\n",
    "\n",
    "    point = ols_mod.params[1]\n",
    "    stderr = ols_mod.bse[1]\n",
    "    epsilon = ols_mod.resid\n",
    "\n",
    "    return point, stderr, yhat, Dhat, resy, resD, epsilon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77qZoF5NpbA8",
   "metadata": {
    "id": "77qZoF5NpbA8"
   },
   "outputs": [],
   "source": [
    "def summary(point, stderr, yhat, Dhat, resy, resD, epsilon, X, D, y, *, name):\n",
    "    '''\n",
    "    Convenience summary function that takes the results of the DML function\n",
    "    and summarizes several estimation quantities and performance metrics.\n",
    "    '''\n",
    "    return pd.DataFrame({'estimate': point,  # point estimate\n",
    "                         'stderr': stderr,  # standard error\n",
    "                         'lower': point - 1.96 * stderr,  # lower end of 95% confidence interval\n",
    "                         'upper': point + 1.96 * stderr,  # upper end of 95% confidence interval\n",
    "                         'rmse y': np.sqrt(np.mean(resy**2)),  # RMSE of model that predicts outcome y\n",
    "                         'rmse D': np.sqrt(np.mean(resD**2))  # RMSE of model that predicts treatment D\n",
    "                         }, index=[name])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "t-g4IiwQi_x0",
   "metadata": {
    "id": "t-g4IiwQi_x0"
   },
   "outputs": [],
   "source": [
    "def dml_sensitivity_bounds_single(res, eta_ysq, eta_asq, alpha=None, inds=None):\n",
    "    ''' Sensitivity analysis, specialized for the partially linear DML moment\n",
    "    E[(yres - theta * Tres) * Tres]. `est` is a `LinearDML` estimator fitted\n",
    "    with `cache_values=True` so that residuals are being stored after fitting.\n",
    "    '''\n",
    "    if inds is None:\n",
    "        inds = np.arange(res[0].shape[0])\n",
    "    yres, Tres = res\n",
    "    yres, Tres = yres[inds], Tres[inds]\n",
    "    nusq = np.mean(Tres ** 2)\n",
    "    theta = np.mean(yres * Tres) / nusq\n",
    "    sigmasq = np.mean((yres - Tres * theta)**2)\n",
    "    S = np.sqrt(sigmasq / nusq)\n",
    "    Casq = eta_asq / (1 - eta_asq)\n",
    "    Cgsq = eta_ysq\n",
    "    error = S * np.sqrt(Casq * Cgsq)\n",
    "\n",
    "    if alpha is None:\n",
    "        return theta - error, theta + error\n",
    "\n",
    "    psi_theta = (yres - Tres * theta) * Tres / nusq\n",
    "    psi_sigmasq = (yres - Tres * theta)**2 - sigmasq\n",
    "    psi_nusq = Tres**2 - nusq\n",
    "\n",
    "    phi_plus = psi_theta\n",
    "    phi_plus += (np.sqrt(Casq * Cgsq) / (2 * S)) * (-(sigmasq / (nusq**2)) * psi_nusq + (1 / nusq) * psi_sigmasq)\n",
    "    stderr_plus = np.sqrt(np.mean(phi_plus**2) / phi_plus.shape[0])\n",
    "\n",
    "    phi_minus = psi_theta\n",
    "    phi_minus -= (np.sqrt(Casq * Cgsq) / (2 * S)) * (-(sigmasq / (nusq**2)) * psi_nusq + (1 / nusq) * psi_sigmasq)\n",
    "    stderr_minus = np.sqrt(np.mean(phi_minus**2) / phi_minus.shape[0])\n",
    "    return theta - error, stderr_minus, theta + error, stderr_plus\n",
    "\n",
    "\n",
    "def dml_sensitivity_contours_single(res, a_upper, y_upper, alpha=None, inds=None):\n",
    "    ''' Specialized for linear DML. Sensitivity bounds contour plots based on a many trained doubly robust\n",
    "    average moment models. If `alpha` is float, incorporates sampling uncertainty\n",
    "    at the `alpha` level. Sensitivity parameter `eta_ysq` ranges in `[0, y_upper]`\n",
    "    and parameter `eta_asq` ranges in `[0, a_upper]`.\n",
    "    '''\n",
    "    xlist = np.linspace(0, a_upper, 100)\n",
    "    ylist = np.linspace(0, y_upper, 100)\n",
    "    X, Y = np.meshgrid(xlist, ylist)\n",
    "    Zlower = np.zeros(X.shape)\n",
    "    Zupper = np.zeros(X.shape)\n",
    "    for itx in np.arange(X.shape[1]):\n",
    "        for ity in np.arange(X.shape[0]):\n",
    "            if alpha is None:\n",
    "                l, u, = dml_sensitivity_bounds_single(res, Y[ity, itx], X[ity, itx], alpha=alpha, inds=inds)\n",
    "            else:\n",
    "                l, _, u, _ = dml_sensitivity_bounds_single(res, Y[ity, itx], X[ity, itx], alpha=alpha, inds=inds)\n",
    "            Zlower[ity, itx] = l\n",
    "            Zupper[ity, itx] = u\n",
    "\n",
    "    return X, Y, Zlower, Zupper"
   ]
  },
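  {
   "cell_type": "markdown",
   "id": "delta-method-sketch",
   "metadata": {
    "id": "delta-method-sketch"
   },
   "source": [
    "The standard errors returned by `dml_sensitivity_bounds_single` above can be read as a delta-method calculation. The following is only a sketch, under the assumption (implicit in the code) that the residuals are treated as given, so that first-stage estimation error is ignored. Write $\\theta$ for the residual-on-residual coefficient (the $\\beta$ of the theorem), $\\nu^2 = E[\\tilde D^2]$, $\\sigma^2 = E[(\\tilde Y - \\theta \\tilde D)^2]$, $S = \\sqrt{\\sigma^2/\\nu^2}$, and $C = \\sqrt{C_a^2 C_g^2}$ with $C_a^2 = \\eta_a^2/(1-\\eta_a^2)$ and $C_g^2 = \\eta_y^2$, matching the variable names in the code. The bounds are $\\theta_\\pm = \\theta \\pm C S$, and a first-order expansion of $S$ in $(\\sigma^2, \\nu^2)$ gives the influence function\n",
    "$$\n",
    "\\varphi_\\pm = \\psi_\\theta \\pm \\frac{C}{2S}\\left( \\frac{1}{\\nu^2}\\psi_{\\sigma^2} - \\frac{\\sigma^2}{\\nu^4} \\psi_{\\nu^2} \\right),\n",
    "$$\n",
    "where $\\psi_\\theta = (\\tilde Y - \\theta \\tilde D)\\tilde D/\\nu^2$, $\\psi_{\\sigma^2} = (\\tilde Y - \\theta \\tilde D)^2 - \\sigma^2$ and $\\psi_{\\nu^2} = \\tilde D^2 - \\nu^2$. The standard error of each bound is then computed as $\\sqrt{E_n[\\varphi_\\pm^2]/n}$, with $E_n$ the sample average."
   ]
  },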
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "SoSr-sqVpc4d",
   "metadata": {
    "id": "SoSr-sqVpc4d"
   },
   "outputs": [],
   "source": [
    "# DML with RLasso\n",
    "modely = make_pipeline(StandardScaler(), lasso_model())\n",
    "modeld = make_pipeline(StandardScaler(), lasso_model())\n",
    "\n",
    "# Run DML model with no crossfitting (change nfolds to >1 to use crossfitting)\n",
    "result_lasso = dml(controls, directlyharmedR, peacefactorR, modely, modeld,\n",
    "                   nfolds=-1, classifier=False, clu=data['village'], cluster=True)\n",
    "table_lasso = summary(*result_lasso, controls, directlyharmedR, peacefactorR, name='RLasso')\n",
    "print(table_lasso)\n",
    "\n",
    "resY_lasso, resD_lasso = result_lasso[4], result_lasso[5]\n",
    "print((\"Controls explain the following fraction of variance of Outcome\",\n",
    "       max(1 - np.var(resY_lasso) / np.var(peacefactorR), 0)))\n",
    "print((\"Controls explain the following fraction of variance of Treatment\",\n",
    "       max(1 - np.var(resD_lasso) / np.var(directlyharmedR), 0)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "HH51ejTcqgT2",
   "metadata": {
    "id": "HH51ejTcqgT2"
   },
   "outputs": [],
   "source": [
    "res = [resY_lasso, resD_lasso]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aW1ySANwjylk",
   "metadata": {
    "id": "aW1ySANwjylk"
   },
   "outputs": [],
   "source": [
    "print(\"[beta - phi, beta + phi]: \", dml_sensitivity_bounds_single(res, 0.16, 0.01))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "BO__qNSje6lS",
   "metadata": {
    "id": "BO__qNSje6lS"
   },
   "outputs": [],
   "source": [
    "eta_asq, eta_ysq, lower, upper = dml_sensitivity_contours_single(res, 0.4, 0.4, alpha=0.05)\n",
    "# Sensitivity parameter `eta_ysq` = R2_{Y ~ A|D} ranges in `[0, y_upper]`\n",
    "# and `eta_asq` = R2_{D ~ A} ranges in `[0, a_upper]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "Z972yBSSk6UX",
   "metadata": {
    "id": "Z972yBSSk6UX"
   },
   "outputs": [],
   "source": [
    "contours = plt.contour(eta_asq, eta_ysq, lower, 6, linestyles='-', colors='blue')\n",
    "plt.clabel(contours, inline=True, fontsize=8)\n",
    "plt.title('Lower limit')\n",
    "plt.xlabel('Partial R2 of confounder(s) with the treatment')\n",
    "plt.ylabel('Partial R2 of confounder(s) with the outcome')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "charged-mauritius",
   "metadata": {
    "id": "charged-mauritius",
    "papermill": {
     "duration": 0.030825,
     "end_time": "2021-04-03T12:55:42.286467",
     "exception": false,
     "start_time": "2021-04-03T12:55:42.255642",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Random Forest for partialling out"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "charitable-placement",
   "metadata": {
    "id": "charitable-placement",
    "papermill": {
     "duration": 0.030332,
     "end_time": "2021-04-03T12:55:42.347072",
     "exception": false,
     "start_time": "2021-04-03T12:55:42.316740",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "The following code does DML with clustered standard errors by village"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "U7XnplWkRd_P",
   "metadata": {
    "id": "U7XnplWkRd_P"
   },
   "outputs": [],
   "source": [
    "# DML with Random Forests. RFs don't require scaling but we do it for consistency\n",
    "modely = make_pipeline(StandardScaler(), RandomForestRegressor(n_estimators=100, min_samples_leaf=5, random_state=123))\n",
    "modeld = make_pipeline(StandardScaler(), RandomForestRegressor(n_estimators=100, min_samples_leaf=5, random_state=123))\n",
    "\n",
    "# Run DML model with nfolds folds of cross-fitting (computationally intensive)\n",
    "result_RF = dml(Z, directlyharmedR, peacefactorR, modely, modeld, nfolds=10,\n",
    "                classifier=False, clu=data['village'], cluster=True)\n",
    "table_RF = summary(*result_RF, Z, directlyharmedR, peacefactorR, name='RF')\n",
    "print(table_RF)\n",
    "\n",
    "resY_RF, resD_RF = result_RF[4], result_RF[5]\n",
    "print((\"Controls explain the following fraction of variance of Outcome\",\n",
    "       max(1 - np.var(resY_RF) / np.var(peacefactorR), 0)))\n",
    "print((\"Controls explain the following fraction of variance of Treatment\",\n",
    "       max(1 - np.var(resD_RF) / np.var(directlyharmedR), 0)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cUxDc1mYdMHH",
   "metadata": {
    "id": "cUxDc1mYdMHH"
   },
   "source": [
    "## Bias Analysis\n",
    "\n",
    "If we want reduce the uncertainty from sample splitting, we can re-run our $nfolds$ cross-fitting and aggregate with the median estimates. With 10 folds as above, there shouldn't be too much variance anyway, but we employ the following aggregation procedure anyway.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "V1tYIFMCeGbQ",
   "metadata": {
    "id": "V1tYIFMCeGbQ"
   },
   "outputs": [],
   "source": [
    "import scipy\n",
    "\n",
    "\n",
    "def dml_sensitivity_bounds(est_list, eta_ysq, eta_asq, alpha=None, inds=None):\n",
    "    if alpha is None:\n",
    "        lower, upper = zip(*[dml_sensitivity_bounds_single(est, eta_ysq, eta_asq, alpha=alpha, inds=inds)\n",
    "                             for est in est_list])\n",
    "        return np.median(lower), np.median(upper)\n",
    "    else:\n",
    "        lower, std_lower, upper, std_upper = zip(*[dml_sensitivity_bounds_single(est, eta_ysq, eta_asq,\n",
    "                                                                                 alpha=alpha, inds=inds)\n",
    "                                                   for est in est_list])\n",
    "        std_lower = np.array(std_lower)\n",
    "        std_upper = np.array(std_upper)\n",
    "        lower = np.median(lower) - scipy.stats.norm.ppf(1 - alpha) * np.sqrt(np.median(std_lower**2) + np.var(lower))\n",
    "        upper = np.median(upper) + scipy.stats.norm.ppf(1 - alpha) * np.sqrt(np.median(std_upper**2) + np.var(upper))\n",
    "        return lower, upper\n",
    "\n",
    "\n",
    "def dml_sensitivity_contours(est_list, a_upper, y_upper, alpha=None, inds=None):\n",
    "    ''' Specialized for linear DML. Sensitivity bounds contour plots based on a many trained doubly robust\n",
    "    average moment models. If `alpha` is float, incorporates sampling uncertainty\n",
    "    at the `alpha` level. Sensitivity parameter `eta_ysq` ranges in `[0, y_upper]`\n",
    "    and parameter `eta_asq` ranges in `[0, a_upper]`.\n",
    "    '''\n",
    "    xlist = np.linspace(0, a_upper, 100)\n",
    "    ylist = np.linspace(0, y_upper, 100)\n",
    "    X, Y = np.meshgrid(xlist, ylist)\n",
    "    Zlower = np.zeros(X.shape)\n",
    "    Zupper = np.zeros(X.shape)\n",
    "    for itx in np.arange(X.shape[1]):\n",
    "        for ity in np.arange(X.shape[0]):\n",
    "            l, u = dml_sensitivity_bounds(est_list, Y[ity, itx], X[ity, itx], alpha=alpha, inds=inds)\n",
    "            Zlower[ity, itx] = l\n",
    "            Zupper[ity, itx] = u\n",
    "\n",
    "    return X, Y, Zlower, Zupper"
   ]
  },
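  {
   "cell_type": "markdown",
   "id": "median-aggregation-sketch",
   "metadata": {
    "id": "median-aggregation-sketch"
   },
   "source": [
    "In symbols, and only as a summary of what `dml_sensitivity_bounds` above computes: given $K$ repetitions with lower bounds $l_k$, upper bounds $u_k$ and standard errors $s_{l,k}$, $s_{u,k}$, the interval reported when `alpha` is set is\n",
    "$$\n",
    "\\left[ \\mathrm{med}_k\\, l_k - z_{1-\\alpha} \\sqrt{\\mathrm{med}_k\\, s_{l,k}^2 + \\mathrm{Var}_k(l_k)}, \\quad \\mathrm{med}_k\\, u_k + z_{1-\\alpha} \\sqrt{\\mathrm{med}_k\\, s_{u,k}^2 + \\mathrm{Var}_k(u_k)} \\right],\n",
    "$$\n",
    "where $z_{1-\\alpha}$ is the standard normal quantile (about $1.645$ for $\\alpha = 0.05$) and $\\mathrm{Var}_k$ is the empirical variance across repetitions, which accounts for the variability introduced by re-fitting. When `alpha` is None, only the medians of the bounds are returned."
   ]
  },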
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "EImLRgnb0peq",
   "metadata": {
    "id": "EImLRgnb0peq"
   },
   "outputs": [],
   "source": [
    "# redefine models without setting state -- random_state controls train_test_split downstream\n",
    "modely = make_pipeline(StandardScaler(), RandomForestRegressor(n_estimators=100, min_samples_leaf=5))\n",
    "modeld = make_pipeline(StandardScaler(), RandomForestRegressor(n_estimators=100, min_samples_leaf=5))\n",
    "res_RF_list = []\n",
    "for i in range(5):\n",
    "    result_RF = dml(Z, directlyharmedR, peacefactorR, modely, modeld, nfolds=10,\n",
    "                    classifier=False, clu=data['village'], cluster=True)\n",
    "    table_RF = summary(*result_RF, Z, directlyharmedR, peacefactorR, name='RF')\n",
    "    resY_RF, resD_RF = result_RF[4], result_RF[5]\n",
    "    res_RF = [resY_RF, resD_RF]\n",
    "    res_RF_list.append(res_RF)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mNk1o3xBTTwr",
   "metadata": {
    "id": "mNk1o3xBTTwr"
   },
   "outputs": [],
   "source": [
    "print(\"[beta - phi, beta + phi]: \", dml_sensitivity_bounds(res_RF_list, 0.16, 0.01))\n",
    "alp = 0.05\n",
    "print(f\"With Sampling Uncertainty at the {alp} level:\", dml_sensitivity_bounds(res_RF_list, .16, .01, alpha=.05))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "obvious-there",
   "metadata": {
    "id": "obvious-there",
    "papermill": {
     "duration": 40.040643,
     "end_time": "2021-04-03T12:56:22.614312",
     "exception": false,
     "start_time": "2021-04-03T12:55:42.573669",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "alp = 0.05\n",
    "eta_asq, eta_ysq, lower, upper = dml_sensitivity_contours(res_RF_list, 0.4, 0.4, alpha=alp)\n",
    "contours = plt.contour(eta_asq, eta_ysq, lower, 6, linestyles='-', colors='blue')\n",
    "plt.clabel(contours, inline=True, fontsize=8)\n",
    "plt.title('Lower limit')\n",
    "plt.xlabel('Partial R2 of confounder(s) with the treatment')\n",
    "plt.ylabel('Partial R2 of confounder(s) with the outcome')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "JihrkBjEYcOG",
   "metadata": {
    "id": "JihrkBjEYcOG"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": [
    {
     "file_id": "1wimiNNG0vdxifPHM13U1ZrgOLWOeMPQO",
     "timestamp": 1702191168183
    },
    {
     "file_id": "https://github.com/CausalAIBook/MetricsMLNotebooks/blob/main/CM4/sensitivity-analysis-with-sensmakr-and-debiased-ml.ipynb",
     "timestamp": 1702189190347
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 60.897772,
   "end_time": "2021-04-03T12:56:22.764591",
   "environment_variables": {},
   "exception": null,
   "input_path": "__notebook__.ipynb",
   "output_path": "__notebook__.ipynb",
   "parameters": {},
   "start_time": "2021-04-03T12:55:21.866819",
   "version": "2.3.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
